import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import numpy as np

class Loss:
    """Loss helpers: reparameterized sampling, BCE/MSE reconstruction losses and a Gaussian KL term."""

    def reparameterize(self, mu, std):
        """
        Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick).

        Parameters:
        - mu: torch.Tensor, mean.
        - std: torch.Tensor, standard deviation; ``eps`` is drawn with its shape/dtype/device.

        Returns:
        - torch.Tensor, the sample, differentiable with respect to ``mu`` and ``std``.
        """
        eps = torch.randn_like(std)
        return mu + eps * std

    @staticmethod
    def bce_loss(output, target):
        """
        Binary cross-entropy (mean reduction) that ignores NaN entries of ``target``.

        Parameters:
        - output: torch.Tensor of probabilities in [0, 1], same shape as ``target``.
        - target: torch.Tensor of labels in [0, 1]; NaN marks a missing value.

        Returns:
        - Scalar tensor: mean BCE over the non-NaN positions of ``target``. When
          ``target`` contains no NaN this is exactly ``nn.BCELoss()(output, target)``.
        """
        missing = torch.isnan(target)
        if missing.any():
            # Mask both tensors so NaN targets contribute neither to the sum nor to
            # the averaging denominator, and cannot inject NaN gradients.
            observed = ~missing
            output, target = output[observed], target[observed]
        return nn.BCELoss()(output, target)

    @staticmethod
    def mse_loss(X, X_hat):
        """
        Mean squared error that ignores positions where ``X`` is NaN.

        Parameters:
        - X: torch.Tensor, reference values; NaN marks a missing value.
        - X_hat: torch.Tensor, reconstruction, same shape as ``X``.

        Returns:
        - Scalar tensor: MSE over the non-NaN positions of ``X``. When ``X`` contains
          no NaN this is exactly ``nn.MSELoss()(X, X_hat)``.
        """
        missing = torch.isnan(X)
        if missing.any():
            # Same masking rule as bce_loss: drop missing positions from both tensors.
            observed = ~missing
            X, X_hat = X[observed], X_hat[observed]
        return nn.MSELoss()(X, X_hat)

    def kl_loss(self, mu, sigma):
        """
        KL divergence between N(mu, sigma^2) and the standard normal N(0, I).

        Computes ``0.5 * sum(mu^2 + sigma^2 - 1 - log(sigma^2))`` over all elements
        and divides the result by ``mu.size(1)``.

        Parameters:
        - mu: torch.Tensor, encoder mean (documented as (batch_size, latent_dim)).
        - sigma: torch.Tensor, encoder *standard deviation* (not variance), same shape as ``mu``.

        Returns:
        - Scalar tensor, the KL loss.

        NOTE(review): the sum also runs over the batch axis but is normalized by
        ``mu.size(1)`` (latent_dim for a 2-D ``mu``), not by the batch size —
        confirm this scaling is intended.
        """
        return 0.5 * torch.sum(mu.pow(2) + sigma.pow(2) - 1 - torch.log(sigma.pow(2))) / mu.size(1)


class Inference:
    """Draw samples from a trained contextual encoder / decoder pair."""

    def __init__(self, contextual_encoder, decoder, device):
        """
        Move both models to ``device`` and keep them for sampling.

        :param contextual_encoder: trained encoder; called as ``contextual_encoder(x)`` to produce a context tensor
        :param decoder: trained decoder; called as ``decoder(z, context)``
        :param device: compute device (cpu or cuda) for the models and inputs
        """
        self.contextual_encoder = contextual_encoder.to(device)
        self.decoder = decoder.to(device)
        self.device = device

    def generate_wave_field(self, wave_buoy_data, num_samples=1):
        """
        Generate wave-field samples conditioned on buoy observations.

        :param wave_buoy_data: buoy observations, documented as (batch_size, channels, height, width)
        :param num_samples: number of samples to draw per input (default 1)
        :return: decoder outputs concatenated along dim 0, i.e. ``num_samples`` blocks of
                 ``batch_size`` rows; block ``k`` holds the ``k``-th sample for every input
        :raises ValueError: if ``num_samples`` < 1
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        # Switch both models to eval mode; the previous mode is not restored afterwards.
        self.contextual_encoder.eval()
        self.decoder.eval()

        with torch.no_grad():
            wave_buoy_data = wave_buoy_data.to(self.device)
            # Context is computed once and shared by all samples.
            context = self.contextual_encoder(wave_buoy_data)

            samples = []
            for _ in range(num_samples):
                # Prior draw z ~ N(0, I), independent of the encoder's mean/variance.
                # z takes the context's shape, which assumes the decoder's latent size
                # equals the context size — TODO confirm.
                z = torch.randn_like(context, device=self.device)
                samples.append(self.decoder(z, context))

            return torch.cat(samples, dim=0)

class EvaluationMetrics:
    """Point (RMSE) and ensemble (CRPS) skill scores for gridded forecasts."""

    def __init__(self, predictions, ground_truth):
        """
        Store forecasts and observations.

        Parameters:
        - predictions: numpy array of forecasts. For RMSE it must hold exactly one
          value per grid cell (typically the same shape as ``ground_truth``). For
          CRPS it may additionally carry a trailing ensemble-member axis, e.g.
          ``[*ground_truth.shape, num_members]``.
        - ground_truth: numpy array of observations, shape ``[num_examples, ...]``.
          Axis 0 indexes examples; all remaining axes (e.g. feature, height, width)
          together form the grid cells.
        """
        self.predictions = predictions
        self.ground_truth = ground_truth
        # One observation per cell, so the cell count comes from ground_truth and
        # covers every non-leading axis (works for both 2-D and 4-D inputs).
        self.num_examples = ground_truth.shape[0]
        self.num_grid_cells = int(np.prod(ground_truth.shape[1:]))

    def _flatten_data(self, data):
        """
        Flatten ``data`` to a 1-D array of ``num_examples * num_grid_cells`` values.

        Raises ValueError (from numpy) if ``data`` does not hold exactly one value
        per grid cell.
        """
        return data.reshape(self.num_examples * self.num_grid_cells)

    def root_mean_square_error(self):
        """
        Compute the root-mean-square error over all examples and grid cells.

        Returns:
        - RMSE value.
        """
        residual = self._flatten_data(self.predictions) - self._flatten_data(self.ground_truth)
        return np.sqrt(np.mean(residual ** 2))

    def _compute_crps_per_grid_cell(self, predictions_per_cell, ground_truth_per_cell):
        """
        CRPS of one grid cell from an ensemble, using the kernel form
        ``E|X - y| - 0.5 * E|X - X'|``.

        Parameters:
        - predictions_per_cell: 1-D array of ensemble members for this cell.
        - ground_truth_per_cell: observed value for this cell.

        Returns:
        - CRPS value for the cell.
        """
        # E|X - y|: mean absolute error of the members against the observation.
        expected_abs_diff = np.mean(np.abs(predictions_per_cell - ground_truth_per_cell))
        # E|X - X'|: mean over all member pairs, including i == j pairs.
        pairwise = np.abs(predictions_per_cell[:, np.newaxis] - predictions_per_cell[np.newaxis, :])
        return expected_abs_diff - np.mean(pairwise) / 2

    def continuous_ranked_probability_score(self, num_members=None):
        """
        Compute the continuous ranked probability score averaged over all cells.

        Ensemble members must lie on the trailing (fastest-varying) axis of
        ``predictions``, so that ``predictions.reshape(num_cells, num_members)``
        lines up row-for-row with the flattened ``ground_truth``. When
        ``predictions`` has the same shape as ``ground_truth`` there is a single
        member per cell and the score reduces to the mean absolute error.

        Parameters:
        - num_members: int, ensemble size per cell. Inferred from the sizes of
          ``predictions`` and ``ground_truth`` when None.

        Returns:
        - CRPS value.

        Raises:
        - ValueError: if ``ground_truth`` is empty or ``predictions`` cannot be split
          into ``num_members`` values per grid cell.
        """
        num_cells = self.num_examples * self.num_grid_cells
        if num_cells == 0:
            raise ValueError("cannot compute CRPS on empty ground truth")
        if num_members is None:
            total = int(np.prod(self.predictions.shape))
            if total % num_cells:
                raise ValueError(
                    f"predictions hold {total} values, not a multiple of {num_cells} grid cells"
                )
            num_members = total // num_cells

        members = self.predictions.reshape(num_cells, num_members)
        observed = self._flatten_data(self.ground_truth)
        crps_per_cell = np.array(
            [self._compute_crps_per_grid_cell(m, o) for m, o in zip(members, observed)]
        )
        return np.mean(crps_per_cell)


class TimeSeriesDataset(Dataset):
    """
    Dataset whose items are ready-made batches of aligned windows from two series.

    Item ``idx`` stacks up to ``batch_size`` consecutive, non-overlapping windows of
    ``time_step_A`` steps taken along axis 1 of ``A``, together with the slice of
    ``B`` covering the same span. Only complete windows are used: a trailing
    remainder of ``A`` shorter than ``time_step_A`` is dropped, because it cannot be
    stacked with full-length windows.
    """

    def __init__(self, A, B, batch_size, time_step_A=3, time_step_B_ratio=3):
        """
        Parameters:
        - A: torch.Tensor, time series A with time on axis 1 (documented as (5, T, 4)).
        - B: torch.Tensor, time series B with time on axis 0 (documented as (T, 4, 128, 128)).
          Step ``t`` of A maps to index ``t // time_step_B_ratio`` of B, so B presumably
          holds one frame per ``time_step_B_ratio`` steps of A — confirm against the data.
        - batch_size: int, number of windows stacked into each item.
        - time_step_A: int, window length in steps of A.
        - time_step_B_ratio: int, number of A steps per B step.
        """
        self.A = A
        self.B = B
        self.batch_size = batch_size
        self.time_step_A = time_step_A
        self.time_step_B_ratio = time_step_B_ratio
        self.num_samples = A.shape[1]  # number of time steps in A
        # Only full-length windows are kept, so they can always be stacked.
        self.num_windows = self.num_samples // self.time_step_A

    def __len__(self):
        # Number of items (pre-built batches) = ceil(num_windows / batch_size).
        return (self.num_windows + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """
        Return ``(A_batch, B_batch)`` for batch ``idx``; negative indices count from the end.

        ``A_batch`` has shape ``(n, A.shape[0], time_step_A, A.shape[2])`` with
        ``n <= batch_size`` (only the last item may be smaller). ``B_batch`` stacks
        the matching B slices. Its axis 1 is squeezed when each slice holds a single
        frame (``time_step_A == time_step_B_ratio``), giving ``(n, *B.shape[1:])``.

        Raises IndexError when ``idx`` is out of range.
        """
        num_batches = len(self)
        if idx < 0:
            idx += num_batches
        if not 0 <= idx < num_batches:
            raise IndexError(f"batch index out of range (dataset has {num_batches} batches)")

        first_window = idx * self.batch_size
        last_window = min(first_window + self.batch_size, self.num_windows)

        A_batches = []
        B_batches = []
        for window in range(first_window, last_window):
            start_idx = window * self.time_step_A
            end_idx = start_idx + self.time_step_A
            A_batches.append(self.A[:, start_idx:end_idx, :])
            # B advances once per time_step_B_ratio steps of A.
            B_batches.append(self.B[start_idx // self.time_step_B_ratio:end_idx // self.time_step_B_ratio])

        A_batches = torch.stack(A_batches, dim=0)
        B_batches = torch.stack(B_batches, dim=0).squeeze(1)
        return A_batches, B_batches


class BuoySeriesDataset(Dataset):
    """
    Dataset whose items are ready-made batches of windows from a single series.

    Item ``idx`` stacks up to ``batch_size`` consecutive, non-overlapping windows of
    ``time_step_A`` steps taken along axis 1 of ``A``. Only complete windows are
    used: a trailing remainder shorter than ``time_step_A`` is dropped, because it
    cannot be stacked with full-length windows.
    """

    def __init__(self, A, batch_size, time_step_A=3):
        """
        Parameters:
        - A: torch.Tensor, time series A with time on axis 1 (documented as (5, T, 4)).
        - batch_size: int, number of windows stacked into each item.
        - time_step_A: int, window length in steps of A.
        """
        self.A = A
        self.batch_size = batch_size
        self.time_step_A = time_step_A
        self.num_samples = A.shape[1]  # number of time steps in A
        # Only full-length windows are kept, so they can always be stacked.
        self.num_windows = self.num_samples // self.time_step_A

    def __len__(self):
        # Number of items (pre-built batches) = ceil(num_windows / batch_size).
        return (self.num_windows + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """
        Return the stacked windows for batch ``idx``; negative indices count from the end.

        The result has shape ``(n, A.shape[0], time_step_A, A.shape[2])`` with
        ``n <= batch_size`` (only the last item may be smaller).

        Raises IndexError when ``idx`` is out of range.
        """
        num_batches = len(self)
        if idx < 0:
            idx += num_batches
        if not 0 <= idx < num_batches:
            raise IndexError(f"batch index out of range (dataset has {num_batches} batches)")

        first_window = idx * self.batch_size
        last_window = min(first_window + self.batch_size, self.num_windows)
        windows = [
            self.A[:, w * self.time_step_A:(w + 1) * self.time_step_A, :]
            for w in range(first_window, last_window)
        ]
        return torch.stack(windows, dim=0)
